#include "nanovoid_app.h"
#include <cstring>

// Builds a solver for an Nx x Ny grid with model parameters `_p`; the gradient
// accumulator `dp` is constructed from 0.0 (presumably all-zero gradients).
//
// Storage layout: old_v / new_v each hold n_channels planes of `size` = Nx * Ny pixels;
// pixel `item` of channel k lives at index item + k * num_items (see grab_vals,
// encode_from_img, forward_one_step). Both buffers are value-initialised (zeroed) so
// that reading them before encode_from_img is well defined.
NanovoidOneBackNormal::NanovoidOneBackNormal(int _Nx, int _Ny, ParameterSet _p)
        :
        Nx(_Nx), Ny(_Ny), size(_Nx * _Ny),
        p(_p), dp(0.0) {
            // Fixed: the table used to be sized size * size * n_channels, i.e. Nx*Ny times
            // larger than the largest index used anywhere (size * n_channels - 1). That
            // wasted memory quadratically and could overflow the product for large grids.
            this->value_table_size = this->size * this->n_channels;
            this->num_items = this->size;
            this->old_v = new valueType[this->value_table_size]();
            this->new_v = new valueType[this->value_table_size]();
        }


// Releases the two value tables allocated in the constructor.
// NOTE(review): the class owns raw buffers but (as far as visible) does not delete its
// copy constructor / copy assignment; copying an instance would double-free. Consider
// `= delete`-ing them in the header or switching to std::vector.
NanovoidOneBackNormal::~NanovoidOneBackNormal() {
    // Both buffers come from `new valueType[...]`, so the array form of delete is
    // required; scalar `delete` on them is undefined behaviour.
    delete[] this->old_v;
    delete[] this->new_v;
}

// Gathers the lap_len_2nd-point stencil around pixel `item` from every channel of
// `value_table` into `vals`.
//
// `item` is decoded as (x, y) = (item / Ny, item % Ny). For stencil slot k the neighbour
// (x + dx[k], y + dy[k]) is clamped into [0, Nx-1] x [0, Ny-1], so neighbours that fall
// off the grid repeat the nearest edge pixel. Channel ch of that neighbour is read from
// value_table[neighbour + ch * size] and stored at vals[ch * lap_len_2nd + k].
void NanovoidOneBackNormal::grab_vals(uint item, valueType *value_table, valueType *vals) {
    const int centre_x = item / Ny;
    const int centre_y = item % Ny;

    for (uint k = 0; k < lap_len_2nd; ++k) {
        // Clamp-to-edge boundary handling (a periodic wrap could be used instead).
        const int nb_x = min(max(centre_x + dx[k], 0), Nx - 1);
        const int nb_y = min(max(centre_y + dy[k], 0), Ny - 1);
        const uint neighbour = static_cast<uint>(nb_x) * Ny + static_cast<uint>(nb_y);

        for (uint ch = 0; ch < n_channels; ++ch) {
            vals[ch * lap_len_2nd + k] = value_table[neighbour + ch * size];
        }
    }
}


// One backward step for pixel `c`, given its gathered stencil `vals` (see grab_vals).
//
// vals layout: channels of lap_len_2nd stencil entries each (centre first). Channels 0..2
// are cv, ci, eta; channels 3..5 (starting at vals_len / 2) are the dloss values for
// cv, ci, eta (assumes vals_len == 6 * lap_len_2nd — TODO confirm against the header).
//
// Writes new_v[c + k * num_items] for k = 0..5:
//   k = 0..2: cv, ci, eta of the previous frame, as reconstructed by forward_one_step_vals;
//   k = 3..5: dloss w.r.t. cv, ci, eta propagated back through the step.
// Also accumulates the parameter gradients into `dp` via accumulate_weight_derivative.
void NanovoidOneBackNormal::forward_one_step(valueType *vals, uint c, valueType *new_v) {

    //  ensure non zero
    // Effective parameters are |raw| + 0.001, matching forward_one_step_vals.
    valueType energy_v = std::abs(p.energy_v0) + 0.001;
    valueType energy_i = std::abs(p.energy_i0) + 0.001;
    valueType kBT = std::abs(p.kBT0) + 0.001;
    valueType kappa_v = std::abs(p.kappa_v0) + 0.001;
    valueType kappa_i = std::abs(p.kappa_i0) + 0.001;
    valueType kappa_eta = std::abs(p.kappa_eta0) + 0.001;
    // NOTE(review): r_bulk, r_surf, p_casc, bias and vg are unused in this function.
    valueType r_bulk = std::abs(p.r_bulk0) + 0.001;
    valueType r_surf = std::abs(p.r_surf0) + 0.001;

    valueType p_casc = std::abs(p.p_casc0) + 0.001;
    valueType bias = std::abs(p.bias0) + 0.001;
    valueType vg = std::abs(p.vg0) + 0.001;
    valueType diff_v = std::abs(p.diff_v0) + 0.001;
    valueType diff_i = std::abs(p.diff_i0) + 0.001;
    valueType L = std::abs(p.L0) + 0.001;

    // not considering previous frame value
    // accumulate_weight_derivative(vals, c);

    // this is to calculate cv ci and eta in previous frame
    // forward_one_step_vals writes only the centre entries back_vals[0],
    // back_vals[lap_len_2nd] and back_vals[lap_len_2nd * 2]; the rest stay uninitialised
    // and are never read below.
    valueType back_vals[this->vals_len / 2];
    forward_one_step_vals(vals, back_vals);
    new_v[c] = back_vals[0];                               // should ensure no zero here
    new_v[c + num_items] = back_vals[lap_len_2nd];         // should ensure no zero here
    new_v[c + num_items * 2] = back_vals[lap_len_2nd * 2]; // should ensure no zero here

    // consider previous frame value
    // Build a full stencil for the previous frame: copy vals (incl. the dloss channels),
    // then shift each state channel uniformly by the change of its centre value.
    // This approximates the previous-frame neighbours; they are not recomputed individually.
    valueType back_vals_full[this->vals_len];
    std::memcpy(back_vals_full, vals, this->vals_len * sizeof(valueType));
    // std::memcpy(back_vals_full, back_vals, this->vals_len / 2 * sizeof(valueType));
    // for (int i = 0; i < this->vals_len; ++ i) {
    //     if (i < (this->vals_len / 2)) {
    //         back_vals_full[i] = back_vals[i]; 
    //     }
    //     else {

    //     }
    // }
    // These three centre assignments are overwritten by the loop below (i == 0).
    back_vals_full[0] = back_vals[0];
    back_vals_full[lap_len_2nd] = back_vals[lap_len_2nd];
    back_vals_full[lap_len_2nd * 2] = back_vals[lap_len_2nd * 2];
    valueType cv_diff = back_vals[0] - vals[0];
    valueType ci_diff = back_vals[lap_len_2nd] - vals[lap_len_2nd];
    valueType eta_diff = back_vals[lap_len_2nd * 2] - vals[lap_len_2nd * 2];
    for (uint i = 0; i < lap_len_2nd; ++i) {
        back_vals_full[i] = vals[i] + cv_diff;
        back_vals_full[lap_len_2nd + i] = vals[lap_len_2nd + i] + ci_diff;
        back_vals_full[lap_len_2nd * 2 + i] = vals[lap_len_2nd * 2 + i] + eta_diff;
    }
    accumulate_weight_derivative(back_vals_full, c);

    // Step scalings: Q and P multiply the Laplacian terms of the cv and ci updates,
    // R multiplies dF/deta in the eta update (same dt as forward_one_step_vals).
    valueType dt = 2e-2;
    valueType mv = diff_v * back_vals[0] / kBT; // detect division by zero
    valueType mi = diff_i * back_vals[lap_len_2nd] / kBT; // detect division by zero
    valueType Q = dt * mv;
    valueType P = dt * mi;
    valueType R = dt * (-L) * N;
    valueType cv = back_vals[0];
    valueType ci = back_vals[lap_len_2nd];
    valueType eta = back_vals[lap_len_2nd * 2];

    // ensure non zero
    // Floor the state so 1 / cv, 1 / ci and the logs below stay finite.
    if (cv < 1e-6)
        cv = 1e-6;

    if (ci < 1e-6)
        ci = 1e-6;

    if (eta < 1e-6)
        eta = 1e-6;

    // NOTE(review): back_dloss is never used.
    valueType back_dloss[this->vals_len / 2];

    // ensure non zero
    valueType one_cv_ci = 1 - cv - ci;
    if (one_cv_ci < 1e-6)
        one_cv_ci = 1e-6;

    // compute dloss_dcv
    //      compute dloss_dcv_dcv_dcv
    // dloss_dcv holds the 5-point neighbourhood of dloss/dcv (channel 3).
    valueType dloss_dcv[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        dloss_dcv[i] = vals[this->vals_len / 2 + i];
    }
    valueType QlapDlossDcv = Q * inner_product(dloss_dcv, lapw, lap_len_1st);
    valueType dloss_dcv_dcv_dcv = vals[vals_len / 2] + QlapDlossDcv * ((eta - 1) * (eta - 1) * kBT *
                                                                   (1 / cv + 1 / one_cv_ci) +
                                                                   2 * eta * eta);
    // NOTE(review): the gradient term uses kappa_v / 2 * Q * Q, whereas the forward step
    // uses kappa_v * dt * mv; confirm this scaling is intended.
    dloss_dcv_dcv_dcv += kappa_v / 2 * Q * Q * inner_product(vals + vals_len / 2, laplapw, lap_len_2nd);

    //      compute dloss_dci_dci_dcv
    valueType dloss_dci[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        dloss_dci[i] = vals[this->vals_len / 2 + lap_len_2nd + i];
    }
    valueType PlapDlossDci = P * inner_product(dloss_dci, lapw, lap_len_1st);
    valueType dloss_dci_dci_dcv = PlapDlossDci * ((eta - 1) * (eta - 1) * kBT *
                                                  1 / one_cv_ci);

    //      compute dloss_deta_deta_dcv
    valueType dloss_deta_deta_dcv = R * vals[this->vals_len / 2 + lap_len_2nd * 2];
    dloss_deta_deta_dcv *= (
            2 * (eta - 1) * (energy_v + kBT * (log_with_mask_single(cv, EPS) - log_with_mask_single(one_cv_ci, EPS))) +
            2 * eta * (2 * (cv - 1)));

    // final assign
    new_v[c + num_items * 3] = dloss_dcv_dcv_dcv + dloss_dci_dci_dcv + dloss_deta_deta_dcv;

    // compute dloss_dci
    //      compute dloss_dci_dci_dci
    // valueType PlapDlossDci = P * inner_product(dloss_dci, lapw, lap_len_1st);
    valueType dloss_dci_dci_dci = vals[vals_len / 2 + lap_len_2nd] + PlapDlossDci * ((eta - 1) * (eta - 1) * kBT *
                                                                                 (1 / ci + 1 / one_cv_ci) +
                                                                                 2 * eta * eta);
    dloss_dci_dci_dci += kappa_i / 2 * P * P * inner_product(vals + vals_len / 2 + lap_len_2nd, laplapw, lap_len_2nd);

    //      compute dloss_dcv_dcv_dci
    valueType dloss_dcv_dcv_dci = QlapDlossDcv * ((eta - 1) * (eta - 1) * kBT *
                                                  1 / one_cv_ci);

    //      compute dloss_deta_deta_dci
    valueType dloss_deta_deta_dci = R * vals[this->vals_len / 2 + lap_len_2nd * 2];
    dloss_deta_deta_dci *= (
            2 * (eta - 1) * (energy_i + kBT * (log_with_mask_single(ci, EPS) - log_with_mask_single(one_cv_ci, EPS))) +
            2 * eta * 2 * ci);

    // final assign
    new_v[c + num_items * 4] = dloss_dci_dci_dci + dloss_dcv_dcv_dci + dloss_deta_deta_dci;

    // compute dloss_deta
    //      compute dloss_dcv_dcv_deta
    valueType dloss_dcv_dcv_deta = QlapDlossDcv * (2 * (eta - 1) * (energy_v + kBT * (log_with_mask_single(cv, EPS) -
                                                                                      log_with_mask_single(one_cv_ci,
                                                                                                           EPS))) +
                                                   2 * eta * 2 * (cv - 1));

    //      compute dloss_dci_dci_deta
    valueType dloss_dci_dci_deta = PlapDlossDci * (2 * (eta - 1) * (energy_i + kBT * (log_with_mask_single(ci, EPS) -
                                                                                      log_with_mask_single(one_cv_ci,
                                                                                                           EPS))) +
                                                   2 * eta * 2 * ci);
    //      compute dloss_deta_deta_deta
    valueType dloss_deta[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        dloss_deta[i] = vals[this->vals_len / 2 + lap_len_2nd * 2 + i];
    }

    // Diagonal term: identity + R * dloss_eta * 2 * (fs + fv), with fs, fv evaluated at
    // the (floored) previous-frame centre values.
    valueType dloss_deta_deta_deta = vals[vals_len / 2 + lap_len_2nd * 2];
    dloss_deta_deta_deta += R * vals[vals_len / 2 + lap_len_2nd * 2] * 2 *
                            (energy_v * cv + energy_i * ci + kBT * (cv * log_with_mask_single(cv, EPS) 
                            + ci * log_with_mask_single(ci, EPS) + one_cv_ci * log_with_mask_single(one_cv_ci, EPS)) + (cv - 1) * (cv - 1) + ci * ci);

    dloss_deta_deta_deta -= R * kappa_eta * inner_product(dloss_deta, lapw, lap_len_1st);

    // final assign
    new_v[c + num_items * 5] = dloss_dcv_dcv_deta + dloss_dci_dci_deta + dloss_deta_deta_deta;

    // Debug aid: dump the input stencil when any propagated gradient is NaN.
    if (debug_on) {
        if (isnan(new_v[c + num_items * 3]) || isnan(new_v[c + num_items * 4]) || isnan(new_v[c + num_items * 5])) {
            fflush(stdout);
            cout << "detect nan value" << endl;
            cout << "vals: ";
            for (int i = 0; i < vals_len; ++i)
                cout << vals[i] << ", ";
            cout << endl;
        }
    }
}

// Outputs cv, ci, eta: results with the sign bit set become 1e-6 and results >= 1.0 become 1.0 (positive values below 1e-6 are left unchanged).
// Reconstructs the previous-frame centre values of cv, ci and eta from the current
// stencil `vals` (channels cv, ci, eta of lap_len_2nd entries each, centre first) by
// stepping the update equations backwards with dt = 2e-2:
//   cv  <- cv  - dt*mv*(lap(h_v) + lap(j_v) - kappa_v*laplap(cv)),   mv = diff_v*cv/kBT
//   ci  <- ci  - dt*mi*(lap(h_i) + lap(j_i) - kappa_i*laplap(ci)),   mi = diff_i*ci/kBT
//   eta <- eta + dt*L*N*(2*(eta-1)*fs + 2*eta*fv - kappa_eta*lap(eta))
// where h_v = (eta-1)^2 * (energy_v + kBT*(log cv - log(1-cv-ci))), j_v = 2*eta^2*(cv-1),
// h_i, j_i analogously, fs is the solute free energy and fv = (cv-1)^2 + ci^2.
// Only new_v[0], new_v[lap_len_2nd] and new_v[lap_len_2nd * 2] are written.
void NanovoidOneBackNormal::forward_one_step_vals(valueType *vals, valueType *new_v) {
    // Effective parameters are |raw| + 0.001 so they are strictly positive.
    // NOTE(review): r_bulk, r_surf, p_casc, bias and vg are unused in this function.
    valueType energy_v = std::abs(p.energy_v0) + 0.001;
    valueType energy_i = std::abs(p.energy_i0) + 0.001;
    valueType kBT = std::abs(p.kBT0) + 0.001;
    valueType kappa_v = std::abs(p.kappa_v0) + 0.001;
    valueType kappa_i = std::abs(p.kappa_i0) + 0.001;
    valueType kappa_eta = std::abs(p.kappa_eta0) + 0.001;
    valueType r_bulk = std::abs(p.r_bulk0) + 0.001;
    valueType r_surf = std::abs(p.r_surf0) + 0.001;

    valueType p_casc = std::abs(p.p_casc0) + 0.001;
    valueType bias = std::abs(p.bias0) + 0.001;
    valueType vg = std::abs(p.vg0) + 0.001;
    valueType diff_v = std::abs(p.diff_v0) + 0.001;
    valueType diff_i = std::abs(p.diff_i0) + 0.001;
    valueType L = std::abs(p.L0) + 0.001;

    // compute cv, ci
    valueType h_dfs_dcv[lap_len_1st];
    valueType h_dfs_dci[lap_len_1st];

    // construct h_dfs_dcv, h_dfs_dci
    // Evaluated at each of the lap_len_1st inner stencil points (not just the centre).
    for (uint i = 0; i < lap_len_1st; i++) {
        h_dfs_dcv[i] = 1.0;
        h_dfs_dci[i] = 1.0;

        //        h_dfs_dcv[i] = (vals[lap_len_2nd * 2 + i] - 1) * (vals[lap_len_2nd * 2 + i] - 1); // (eta-1)**2
        //        h_dfs_dci[i] = h_dfs_dcv[i];

        valueType log_cv = log_with_mask_single(vals[i], EPS);
        valueType log_ci = log_with_mask_single(vals[lap_len_2nd + i], EPS);
        valueType log_1_cv_ci = log_with_mask_single(1 - vals[i] - vals[i + lap_len_2nd], EPS);

        h_dfs_dcv[i] = h_dfs_dcv[i] * (energy_v + kBT * (log_cv - log_1_cv_ci));
        h_dfs_dci[i] = h_dfs_dci[i] * (energy_i + kBT * (log_ci - log_1_cv_ci));
        // Mask the solute chemical potential where 1 - cv - ci is (nearly) non-positive.
        if ((1 - vals[i] - vals[i + lap_len_2nd]) < EPS) {
            h_dfs_dcv[i] = 0;
            h_dfs_dci[i] = 0;
        }

        h_dfs_dcv[i] = h_dfs_dcv[i] * (vals[lap_len_2nd * 2 + i] - 1) * (vals[lap_len_2nd * 2 + i] - 1); // (eta-1)**2
        h_dfs_dci[i] = h_dfs_dci[i] * (vals[lap_len_2nd * 2 + i] - 1) * (vals[lap_len_2nd * 2 + i] - 1);
    }

    valueType j_dfv_dcv[lap_len_1st];
    valueType j_dfv_dci[lap_len_1st];

    // j_v = eta^2 * 2 * (cv - 1), j_i = eta^2 * 2 * ci, per inner stencil point.
    for (uint i = 0; i < lap_len_1st; i++) {
        j_dfv_dcv[i] = vals[lap_len_2nd * 2 + i] * vals[lap_len_2nd * 2 + i]; // eta**2
        j_dfv_dci[i] = j_dfv_dcv[i];

        j_dfv_dcv[i] = j_dfv_dcv[i] * 2 * (vals[i] - 1);
        j_dfv_dci[i] = j_dfv_dci[i] * 2 * vals[lap_len_2nd + i];
    }

    // Mobilities use the centre concentrations only.
    valueType dt = 2e-2;
    valueType mv = diff_v * vals[0] / kBT;
    valueType mi = diff_i * vals[lap_len_2nd] / kBT;

    valueType dt_mv_lap_h_dfs_dcv = dt * mv * inner_product(h_dfs_dcv, lapw, lap_len_1st);
    valueType dt_mv_lap_j_dfv_dcv = dt * mv * inner_product(j_dfv_dcv, lapw, lap_len_1st);
    valueType dt_mv_lap_lap_cv = -dt * mv * inner_product(vals, laplapw, lap_len_2nd);

    new_v[0] = vals[0] - (dt_mv_lap_h_dfs_dcv + dt_mv_lap_j_dfv_dcv + kappa_v * dt_mv_lap_lap_cv);

    valueType dt_mi_lap_h_dfs_dci = dt * mi * inner_product(h_dfs_dci, lapw, lap_len_1st);
    valueType dt_mi_lap_j_dfv_dci = dt * mi * inner_product(j_dfv_dci, lapw, lap_len_1st);
    valueType dt_mi_lap_lap_ci = -dt * mi * inner_product(vals + lap_len_2nd, laplapw, lap_len_2nd);

    new_v[0 + lap_len_2nd] =
            vals[lap_len_2nd] - (dt_mi_lap_h_dfs_dci + dt_mi_lap_j_dfv_dci + kappa_i * dt_mi_lap_lap_ci);

    // compute eta
    // fs
    // fs = energy_v*cv + energy_i*ci + kBT*(cv log cv + ci log ci + (1-cv-ci) log(1-cv-ci)),
    // forced to 0 where 1 - cv - ci < EPS (same mask as h_dfs_* above).
    valueType fs = energy_v * vals[0] + energy_i * vals[lap_len_2nd];
    fs = fs + kBT * (vals[0] * log_with_mask_single(vals[0], EPS));
    fs = fs + kBT * (vals[lap_len_2nd] * log_with_mask_single(vals[lap_len_2nd], EPS));
    fs = fs + kBT * ((1 - vals[0] - vals[lap_len_2nd]) * log_with_mask_single(1 - vals[0] - vals[lap_len_2nd], EPS));
    if ((1 - vals[0] - vals[lap_len_2nd]) < EPS) {
        fs = 0;
    }
    // fv
    valueType fv = (vals[0] - 1) * (vals[0] - 1) + vals[lap_len_2nd] * vals[lap_len_2nd];

    valueType dF_deta = N * (fs * 2 * (vals[lap_len_2nd * 2] - 1) + fv * 2 * vals[lap_len_2nd * 2] -
                             kappa_eta * inner_product(vals + lap_len_2nd * 2, lapw, lap_len_1st));

    // eta_0 = eta_1 - .... 
    new_v[0 + lap_len_2nd * 2] = vals[lap_len_2nd * 2] - dt * (-L) * dF_deta;

    // Post-process: anything with the sign bit set (including -0.0) becomes 1e-6 ...
    if (std::signbit(new_v[0])) {
        new_v[0] = 1e-6;
    }

    if (std::signbit(new_v[0 + lap_len_2nd])) {
        new_v[0 + lap_len_2nd] = 1e-6;
    }

    if (std::signbit(new_v[0 + lap_len_2nd * 2])) {
        new_v[0 + lap_len_2nd * 2] = 1e-6;
    }

    // ... and anything >= 1.0 is capped at 1.0.
    if (new_v[0] >= 1.0) {
        new_v[0] = 1.0;
    }

    if (new_v[0 + lap_len_2nd] >= 1.0) {
        new_v[0 + lap_len_2nd] = 1.0;
    }

    if (new_v[0 + lap_len_2nd * 2] >= 1.0) {
        new_v[0 + lap_len_2nd * 2] = 1.0;
    }
}



// In-place natural log of the first `len` entries of `mat`, after raising every entry
// below `eps` up to `eps` (keeps the log finite when eps > 0).
void NanovoidOneBackNormal::log_with_mask(valueType *mat, valueType eps, uint len) {
    valueType *const end = mat + len;
    for (valueType *cell = mat; cell != end; ++cell) {
        const valueType floored = (*cell < eps) ? eps : *cell;
        *cell = log(floored);
    }
}


// Returns log(max(p, eps)): values below `eps` are floored to `eps` before the log,
// which keeps the result finite when eps > 0. A NaN input still yields NaN.
valueType NanovoidOneBackNormal::log_with_mask_single(valueType p, valueType eps) {
    const valueType floored = (p < eps) ? eps : p;
    return log(floored);
}


// Overwrites with `eps` every entry mat[k] (k < len) whose mask[k] equals exactly 1;
// all other entries are left untouched.
void NanovoidOneBackNormal::masked_fill(valueType *mat, int *mask, valueType eps, uint len) {
    for (uint k = 0; k != len; ++k) {
        if (mask[k] != 1) {
            continue;
        }
        mat[k] = eps;
    }
}


// Loads the state image and its loss gradient into old_v.
// For every pixel, img[x][y][0..2] (cv, ci, eta) go to channels 0..2 and
// dloss[x][y][0..2] (dloss w.r.t. cv, ci, eta) go to channels 3..5, i.e. index
// item + k * num_items with item = c.to_item_c1(Nx, size).
void NanovoidOneBackNormal::encode_from_img(valueType ***img, valueType ***dloss) {

    Coordinate2d3c c(0, 0);
    for (c.x = 0; c.x < Nx; ++c.x) {
        for (c.y = 0; c.y < Ny; ++c.y) {
            const uint base = c.to_item_c1(Nx, size);
            const valueType *state = img[c.x][c.y];
            const valueType *grad = dloss[c.x][c.y];
            for (uint k = 0; k < 3; ++k) {
                old_v[base + num_items * k] = state[k];        // cv, ci, eta
            }
            for (uint k = 0; k < 3; ++k) {
                old_v[base + num_items * (k + 3)] = grad[k];   // dloss_dcv, dloss_dci, dloss_deta
            }
        }
    }

    if (debug_on) {
        printf("after assign old_v\n");
        fflush(stdout);
    }

}


// Copies the first n_channels - 3 channels of old_v (cv, ci, eta when n_channels is 6)
// into a freshly allocated Nx x Ny x (n_channels - 3) jagged array.
// Ownership: every level is allocated with new[] here and nothing in this block frees
// it; the caller is responsible for releasing it with delete[] at each level.
valueType ***NanovoidOneBackNormal::decode_to_img() {
    Coordinate2d3c c(0, 0);

    valueType ***image = new valueType **[Nx];
    for (c.x = 0; c.x < Nx; c.x++) {
        valueType **column = new valueType *[Ny];
        for (c.y = 0; c.y < Ny; c.y++) {
            const uint pixel = c.to_item_c1(Nx, num_items);
            valueType *channels = new valueType[n_channels - 3];
            uint ch = 0;
            while (ch < (n_channels - 3)) {
                channels[ch] = old_v[pixel + num_items * ch];
                ++ch;
            }
            column[c.y] = channels;
        }
        image[c.x] = column;
    }
    return image;
}


// Accumulates, for one pixel, the loss gradient w.r.t. the nine learnable parameters
// (energy_v, energy_i, kBT, kappa_v, kappa_i, kappa_eta, diff_v, diff_i, L) into `dp`.
//
// vals layout: lap_len_2nd stencil entries per channel, centre first; channels 0..2 are
// cv, ci, eta and channels 3..5 are dloss w.r.t. cv, ci, eta.
// For each parameter the local sensitivities d{cv,ci,eta}/d(param) of one update step
// are formed and contracted with the centre dloss values. The effective parameters are
// |raw| + 0.001, so each sensitivity is multiplied by sign(raw) (raw <= 0 is treated
// as negative) to chain through the absolute value. A channel contributes only when
// its centre value lies strictly inside (1e-6, 1.0), presumably to skip values that
// were clamped by forward_one_step_vals.
//
// @param vals  stencil gathered around the pixel (layout above).
// @param c     pixel index; currently unused.
void NanovoidOneBackNormal::accumulate_weight_derivative(valueType *vals, uint c) {
    const valueType energy_v = std::abs(p.energy_v0) + 0.001;
    const valueType energy_i = std::abs(p.energy_i0) + 0.001;
    const valueType kBT = std::abs(p.kBT0) + 0.001;
    const valueType kappa_v = std::abs(p.kappa_v0) + 0.001;
    const valueType kappa_i = std::abs(p.kappa_i0) + 0.001;
    const valueType kappa_eta = std::abs(p.kappa_eta0) + 0.001;
    const valueType diff_v = std::abs(p.diff_v0) + 0.001;
    const valueType diff_i = std::abs(p.diff_i0) + 0.001;
    const valueType L = std::abs(p.L0) + 0.001;

    const valueType dt = 2e-2;
    const valueType mv = diff_v * vals[0] / kBT;            // diff_v * cv / kBT
    const valueType mi = diff_i * vals[lap_len_2nd] / kBT;  // diff_i * ci / kBT

    // Fixed at 1 here; the bucket-size lookup is disabled in this variant.
    const uint bucket_size = 1;

    // Floored so the logs below stay finite. Because of this floor, no
    // "1 - cv - ci < threshold" masking can trigger in this function.
    // NOTE(review): forward_one_step_vals instead zeroes fs / dfs where 1 - cv - ci < EPS;
    // confirm whether the same masking is wanted here.
    valueType one_cv_ci = 1 - vals[0] - vals[lap_len_2nd];
    if (one_cv_ci < 1e-6)
        one_cv_ci = 1e-6;

    const valueType cv = vals[0];
    const valueType ci = vals[lap_len_2nd];
    const valueType eta = vals[lap_len_2nd * 2];

    // Shared pure sub-expressions (centre values only).
    const valueType log_cv = log_with_mask_single(cv, EPS);
    const valueType log_ci = log_with_mask_single(ci, EPS);
    const valueType log_one_cv_ci = log_with_mask_single(one_cv_ci, EPS);
    const valueType log_ratio_v = log_cv - log_one_cv_ci;
    const valueType log_ratio_i = log_ci - log_one_cv_ci;
    const valueType entropy_sum = cv * log_cv + ci * log_ci + one_cv_ci * log_one_cv_ci;

    const valueType laplap_cv = inner_product(vals, laplapw, lap_len_2nd);
    const valueType laplap_ci = inner_product(vals + lap_len_2nd, laplapw, lap_len_2nd);
    const valueType lap_eta = inner_product(vals + lap_len_2nd * 2, lapw, lap_len_1st);

    // (eta - 1)^2 over the inner (Laplacian) stencil.
    valueType eta_1_sq[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        const valueType e = vals[lap_len_2nd * 2 + i];
        eta_1_sq[i] = (e - 1) * (e - 1);
    }
    const valueType lap_eta_1_sq = inner_product(eta_1_sq, lapw, lap_len_1st);

    // ---- energy_v ----
    valueType dcv_dev = dt * mv * lap_eta_1_sq;
    valueType dci_dev = 0.0;
    valueType deta_dev = dt * (-L) * N * 2 * (eta - 1) * cv;
    if (p.energy_v0 <= 0.0) {
        dcv_dev = -dcv_dev;
        dci_dev = -dci_dev;
        deta_dev = -deta_dev;
    }

    // ---- energy_i ----
    valueType dcv_dei = 0.0;
    valueType dci_dei = dt * mi * lap_eta_1_sq;
    valueType deta_dei = dt * (-L) * N * 2 * (eta - 1) * ci;
    if (p.energy_i0 <= 0) {
        dcv_dei = -dcv_dei;
        dci_dei = -dci_dei;
        deta_dei = -deta_dei;
    }

    // ---- kBT ---- (the kBT dependence of the mobilities mv / mi is not included)
    valueType eta_1_sq_cv[lap_len_1st];
    valueType eta_1_sq_ci[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        eta_1_sq_cv[i] = eta_1_sq[i] * log_ratio_v;
        eta_1_sq_ci[i] = eta_1_sq[i] * log_ratio_i;
    }
    valueType dcv_dkBT = dt * mv * inner_product(eta_1_sq_cv, lapw, lap_len_1st);
    // Fixed: the ci update is scaled by the ci mobility mi (was the cv mobility mv).
    valueType dci_dkBT = dt * mi * inner_product(eta_1_sq_ci, lapw, lap_len_1st);
    valueType deta_dkBT = dt * (-L) * N * 2 * (eta - 1) * entropy_sum;
    if (p.kBT0 <= 0) {
        dcv_dkBT = -dcv_dkBT;
        dci_dkBT = -dci_dkBT;
        deta_dkBT = -deta_dkBT;
    }

    // ---- kappa_v ----
    valueType dcv_dkappa_v = -dt * mv * laplap_cv;
    valueType dci_dkappa_v = 0.0;
    valueType deta_dkappa_v = 0.0;
    if (p.kappa_v0 <= 0) {
        dcv_dkappa_v = -dcv_dkappa_v;
        dci_dkappa_v = -dci_dkappa_v;
        deta_dkappa_v = -deta_dkappa_v;
    }

    // ---- kappa_i ----
    valueType dcv_dkappa_i = 0.0;
    valueType dci_dkappa_i = -dt * mi * laplap_ci;
    valueType deta_dkappa_i = 0.0;
    if (p.kappa_i0 <= 0) {
        dcv_dkappa_i = -dcv_dkappa_i;
        dci_dkappa_i = -dci_dkappa_i;
        deta_dkappa_i = -deta_dkappa_i;
    }

    // ---- kappa_eta ----
    valueType dcv_dkappa_eta = 0.0;
    valueType dci_dkappa_eta = 0.0;
    valueType deta_dkappa_eta = dt * L * N * lap_eta;
    if (p.kappa_eta0 <= 0) {
        dcv_dkappa_eta = -dcv_dkappa_eta;
        dci_dkappa_eta = -dci_dkappa_eta;
        deta_dkappa_eta = -deta_dkappa_eta;
    }

    // ---- diff_v ----
    // NOTE(review): dfs and the eta^2 term use the centre cv / eta for every stencil
    // point, whereas forward_one_step_vals evaluates them per neighbour. Also, unlike the
    // other sensitivities here, the diff_v / diff_i ones carry the opposite sign of
    // d(forward update)/d(param); confirm both are intended.
    const valueType dfs_dcv = energy_v + kBT * log_ratio_v;
    valueType eta_1_sq_eta_sq_cv[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        eta_1_sq_eta_sq_cv[i] = eta_1_sq[i] * dfs_dcv + eta * eta * 2 * (cv - 1);
    }
    valueType dcv_ddiff_v = -dt * cv / kBT * (inner_product(eta_1_sq_eta_sq_cv, lapw, lap_len_1st) -
                                              kappa_v * laplap_cv);
    valueType dci_ddiff_v = 0.0;
    valueType deta_ddiff_v = 0.0;
    if (p.diff_v0 <= 0) {
        dcv_ddiff_v = -dcv_ddiff_v;
        dci_ddiff_v = -dci_ddiff_v;
        deta_ddiff_v = -deta_ddiff_v;
    }

    // ---- diff_i ----
    const valueType dfs_dci = energy_i + kBT * log_ratio_i;
    valueType eta_1_sq_eta_sq_ci[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        eta_1_sq_eta_sq_ci[i] = eta_1_sq[i] * dfs_dci + eta * eta * 2 * ci;
    }
    valueType dcv_ddiff_i = 0.0;
    valueType dci_ddiff_i = -dt * ci / kBT * (inner_product(eta_1_sq_eta_sq_ci, lapw, lap_len_1st) -
                                              kappa_i * laplap_ci);
    valueType deta_ddiff_i = 0.0;
    if (p.diff_i0 <= 0) {
        dcv_ddiff_i = -dcv_ddiff_i;
        dci_ddiff_i = -dci_ddiff_i;
        deta_ddiff_i = -deta_ddiff_i;
    }

    // ---- L ----
    const valueType fs = energy_v * cv + energy_i * ci + kBT * entropy_sum;
    const valueType fv = (cv - 1) * (cv - 1) + ci * ci;
    valueType dcv_dL = 0.0;
    valueType dci_dL = 0.0;
    valueType deta_dL = -dt * N * (fs * 2 * (eta - 1) + fv * 2 * eta - kappa_eta * lap_eta);
    if (p.L0 <= 0) {
        dcv_dL = -dcv_dL;
        dci_dL = -dci_dL;
        deta_dL = -deta_dL;
    }

    // ---- contract with the centre dloss values ----
    const valueType g_cv = vals[lap_len_2nd * 3];
    const valueType g_ci = vals[lap_len_2nd * 4];
    const valueType g_eta = vals[lap_len_2nd * 5];

    if (!(cv <= 1e-6 || cv >= 1.0)) {
        dp.energy_v0 += bucket_size * g_cv * dcv_dev;
        dp.energy_i0 += bucket_size * g_cv * dcv_dei;
        dp.kBT0 += bucket_size * g_cv * dcv_dkBT;
        dp.kappa_v0 += bucket_size * g_cv * dcv_dkappa_v;
        dp.kappa_i0 += bucket_size * g_cv * dcv_dkappa_i;
        dp.kappa_eta0 += bucket_size * g_cv * dcv_dkappa_eta;
        dp.diff_v0 += bucket_size * g_cv * dcv_ddiff_v;
        dp.diff_i0 += bucket_size * g_cv * dcv_ddiff_i;
        dp.L0 += bucket_size * g_cv * dcv_dL;
    }
    if (!(ci <= 1e-6 || ci >= 1.0)) {
        // Fixed: was g_ci * dcv_dev (copy-paste from the cv branch).
        dp.energy_v0 += bucket_size * g_ci * dci_dev;
        dp.energy_i0 += bucket_size * g_ci * dci_dei;
        dp.kBT0 += bucket_size * g_ci * dci_dkBT;
        dp.kappa_v0 += bucket_size * g_ci * dci_dkappa_v;
        dp.kappa_i0 += bucket_size * g_ci * dci_dkappa_i;
        dp.kappa_eta0 += bucket_size * g_ci * dci_dkappa_eta;
        dp.diff_v0 += bucket_size * g_ci * dci_ddiff_v;
        dp.diff_i0 += bucket_size * g_ci * dci_ddiff_i;
        dp.L0 += bucket_size * g_ci * dci_dL;
    }
    if (!(eta <= 1e-6 || eta >= 1.0)) {
        // Fixed: was g_eta * dcv_dev (copy-paste from the cv branch).
        dp.energy_v0 += bucket_size * g_eta * deta_dev;
        dp.energy_i0 += bucket_size * g_eta * deta_dei;
        dp.kBT0 += bucket_size * g_eta * deta_dkBT;
        dp.kappa_v0 += bucket_size * g_eta * deta_dkappa_v;
        dp.kappa_i0 += bucket_size * g_eta * deta_dkappa_i;
        dp.kappa_eta0 += bucket_size * g_eta * deta_dkappa_eta;
        dp.diff_v0 += bucket_size * g_eta * deta_ddiff_v;
        dp.diff_i0 += bucket_size * g_eta * deta_ddiff_i;
        dp.L0 += bucket_size * g_eta * deta_dL;
    }
}


// Prints the accumulated parameter gradients held in `dp`, one per line.
// C stdio is flushed first so this report is not interleaved with pending printf output.
void NanovoidOneBackNormal::print_derivative() {
    fflush(stdout);
    cout << "derivative of weight: " << endl
         << "energy_v: " << dp.energy_v0 << endl
         << "energy_i: " << dp.energy_i0 << endl
         << "kBT: " << dp.kBT0 << endl
         << "kappa_v: " << dp.kappa_v0 << endl
         << "kappa_i: " << dp.kappa_i0 << endl
         << "kappa_eta: " << dp.kappa_eta0 << endl
         << "diff_v: " << dp.diff_v0 << endl
         << "diff_i: " << dp.diff_i0 << endl
         << "L: " << dp.L0 << endl;
}

void NanovoidOneBackNormal::next() {
    valueType vals[this->vals_len];
    for (int i = 0; i < this->Nx; i++) {
        for (int j = 0; j < this->Ny; j++) {
            grab_vals(i * this->Ny + j, old_v, vals);
            forward_one_step(vals, i * this->Ny + j, new_v);
            if ((i * this->Ny + j) % 100 == 0) {
                printf("step: %d\n", i * this->Ny + j);
            }
        }
    }
    std::memcpy(old_v, new_v, this->value_table_size * sizeof(valueType));
}




// Stencil offsets (dx[k], dy[k]) used by grab_vals:
//   k = 0       centre,
//   k = 1..4    axis neighbours at distance 1,
//   k = 5..8    diagonal neighbours,
//   k = 9..12   axis neighbours at distance 2.
// Presumably lap_len_1st == 5 and lap_len_2nd == 13 — TODO confirm against the header.
const int NanovoidOneBackNormal::dx[] = {0, 1, 0, -1, 0, 1, -1, 1, -1, 2, 0, -2, 0};
const int NanovoidOneBackNormal::dy[] = {0, 0, 1, 0, -1, 1, 1, -1, -1, 0, 2, 0, -2};
// 13-point weights over the offsets above: centre 20, axis -8, diagonal 2, distance-2 axis 1.
// These are the standard discrete bi-Laplacian weights for unit grid spacing.
const valueType NanovoidOneBackNormal::laplapw[] = {20, -8, -8, -8, -8, 2, 2, 2, 2, 1, 1, 1, 1};
// 5-point Laplacian weights over the first five offsets: centre -4, axis neighbours 1.
const valueType NanovoidOneBackNormal::lapw[] = {-4, 1, 1, 1, 1};